// file: arch/page.c
// autor:jiangxinpeng
// time:2020.12.27
// copyright:(C) 2020-2050 by jiang xinpeng,All right are reserved.

#include <arch/page.h>
#include <arch/pymem.h>
#include <arch/swapswitch.h>
#include <arch/mem_pool.h>
#include <arch/x86.h>
#include <os/swap.h>
#include <os/memspace.h>
#include <os/exception.h>
#include <os/kernelif.h>
#include <os/schedule.h>
#include <os/debug.h>
#include <os/task.h>
#include <os/spinlock.h>
#include <os/swap.h>
#include <os/memcache.h>
#include <lib/type.h>
#include <lib/string.h>
#include <sys/fcntl.h>
#include <lib/unistd.h>

//#define DEBUG_PAGE_MAP

// page operator lock used to smp mode
DEFINE_SPIN_LOCK(page_spin_lock);

// count no page
static uint32_t nopage_count = 0;
// disk file swap out offset
static uint32_t swapout_off = 0;

// indiate page enable status
uint8_t page_enable = 0;

// do handle no page and map page
static int DoNoPage(uint64_t addr, uint64_t prot)
{
        return _MapPage(addr, PAGE_SIZE, prot);
}

// whether virtual page had been bind to pysical page
// return:  1= pde and pte are present
//          0= pde present,pte no present
//         -1= pde and pte are no present
int CheckPagePresent(uint64_t vaddr)
{
        if ((*GetPdeVptr(vaddr) & PAGE_PRESENT) && (*GetPteVptr(vaddr) & PAGE_PRESENT))
                return 1;
        else if ((*GetPdeVptr(vaddr) & PAGE_PRESENT))
                return 0;
        else
                return -1;
}

// do swap page and swap out page from disk
static int DoNoPypage(uint64_t addr)
{
        uint64_t page;
        pde_t *pde = GetPdeVptr(addr);
        pte_t *pte = GetPteVptr(addr);

        if ((*pde & PAGE_PRESENT) && !(*pte & PAGE_PRESENT))
        {
                // alloc a pypage
                if (*pte & PAGE_USERPAGE)
                        page = ADDR(AllocUserPage(1));
                else
                        page = ADDR(AllocKernelPage(1));

                if (!page)
                {
                        KPrint("memory no enough when do page swap in!\n");
                        return -1;
                }
                #if SWAP_ENABLE
                // try to swap in page from file
                SwapIn(addr);
                #endif 
        }
}

// try to write to  no writeable area
static int DoPageNoWrite(uint64_t addr, int kernel)
{
        if (kernel)
        {
                // kernel
                KPrint("page fault at %x\n", addr);
                KPrint("in kernel mode,pid[%d] try to write no permission area!\n", cur_task->pid);
                Panic("System hlt...");
        }
        else
        {
                // user
                KPrint("protect fault at %x\n", addr);
                KPrint("in user mode,pid[%d] try to write no permission area!\n", cur_task->pid);
                ExceptionForceSelf(EXC_CODE_SEGV);
        }
}

// try to read from no readable area
static int DoPageNoRead(uint64_t addr, int kernel)
{

        if (kernel)
        {
                // kernel
                KPrint("page fault at %x\n", addr);
                KPrint("In kernel mode,pid[%d] try to read no readable area!\n", cur_task->pid);
                Panic("System hlt...");
        }
        else
        {
                // user
                KPrint("page fault at %x\n", addr);
                KPrint("In user mode,pid[%d] try to read no readable area!\n", cur_task->pid);
                ExceptionForceSelf(EXC_CODE_SEGV);
        }
        return 0;
}

static int DoProtectionFault(trap_frame_t *frame, uint64_t addr, int write)
{
        if (write)
        {
                KPrint("addr %x have write protection!\n", addr);
                DoPageNoWrite(addr, frame->error_code != PAGE_ERR_USER);
                ExceptionForceSelf(EXC_CODE_SEGV);
                return 0;
        }
        else
        {
                KPrint("addr %x have read protection!\n", addr);
                DoPageNoRead(addr, frame->error_code != PAGE_ERR_USER);
                ExceptionForceSelf(EXC_CODE_SEGV);
                return 0;
        }
}

// page fault handler function
int PageDoFault(trap_frame_t *frame)
{
        task_t *cur = cur_task;
        uint64_t addr = ReadCR2(); // read page fault addr

        KPrint("[page fault] at %x\n", addr);

        // process no have vmm
        if (!(cur->vmm))
        {
                KPrint("process no vmm!\n", addr);
                DumpTrapFrame(frame);
                Panic("halt");
        }

        // increase nopage count
        nopage_count++;

        // page protect fault
        if (frame->error_code == PAGE_ERR_PROTECT)
        {
                // according run mode to do access protect handler
                if (DoProtectionFault(frame, addr, frame->error_code != PAGE_ERR_USER) < 0)
                        return -1;
                return 0;
        }

        // addr in kernel space
        if (!(addr < USER_VMM_BASE_ADDR && addr > USER_STACK_TOP))
        {
                if (frame->error_code == PAGE_ERR_USER) // user mode try to access kernel space
                {
                        KPrint("page fault at %x\n");
                        KPrint("In user mode,try to access no permission address!\n", cur->pid);
                        DumpTrapFrame(frame);
                        ExceptionForceSelf(EXC_CODE_SEGV);
                }
                else
                {
                        // kernel mode access kernel space occured nopage
                        if (CheckPagePresent(addr) <= 0) // page no present
                        {
                                KPrint("page fault at %x\n", addr);
                                KPrint("a memory problem occured at kernel!\n");
                                DumpTrapFrame(frame);
                                Panic("system hlt...\n");
                        }
                        else if (CheckPagePresent(addr) == 0) // pde present but pypage maybe swap to disk
                        {
                                if (DoNoPypage(addr) < 0) // no pypage
                                {
                                        KPrint("page fault at %x\n", addr);
                                        KPrint("a memory problem occured at kernel!\n");
                                        DumpTrapFrame(frame);
                                        Panic("System hlt..\n");
                                }
                        }
                }
        }
        else
        {
                // addr in user space
                mem_space_t *space = MemSpaceFind(cur->vmm, addr);
                if (!space)
                {
                        KPrint("page fault at %x\n", addr);
                        KPrint("pid[%d] try to access unknow space!\n", cur->pid);
                        DumpTrapFrame(frame);
                        ExceptionForceSelf(EXC_CODE_SEGV);
                }

                if (CheckPagePresent(addr) <= 0) // page no present
                {
                        if (DoNoPage(addr, space->protect) < 0) // try to map a page
                        {
                                KPrint("page fault at %x\n", addr);
                                KPrint("pid[%d] access unmap space and try to do page map fault!\n", cur->pid);
                                DumpTrapFrame(frame);
                                ExceptionForceSelf(EXC_CODE_SEGV);
                        }
                }
                else if (CheckPagePresent(addr) == 0) // pde present but pypage maybe swap to disk
                {
                        if (DoNoPypage(addr) < 0) // no pypage
                        {
                                KPrint("page fault at %x\n", addr);
                                KPrint("pid[%d] access unmap space and try to swap out pypage fault!\n", cur->pid);
                                DumpTrapFrame(frame);
                                ExceptionForceSelf(EXC_CODE_SEGV);
                        }
                }
        }
}

// check page if can read
int CheckPageRead(uint32_t vbase, uint32_t len)
{
        uint32_t base = vbase&PAGEBASE_MASK;
        uint32_t total = PAGE_ALIGN(len);
        while (total)
        {
                pte_t *page = GetPteVptr(base);
                if (page != NULL)
                {
                        if (!(*page & PAGE_PRESENT))
                        {
                                return -1;
                        }
                }
                base += PAGE_SIZE;
                total -= PAGE_SIZE;
        }
        return 0;
}

int CheckPageWrite(uint32_t vbase, uint32_t len)
{
        uint32_t base =vbase&PAGEBASE_MASK;
        uint32_t total = PAGE_ALIGN(len);
        while (total)
        {
                pte_t *page = GetPteVptr(base);

                if (CheckPageRead(base, len) < 0)
                        return -1;
                if (!(*page & PAGE_WRITE))
                        return -1;

                base += PAGE_SIZE;
                total -= PAGE_SIZE;
        }
        return 0;
}

// disable remap pypage
int _VbaseLinkPybaseUnsafe(uint32_t vbase, uint32_t pybase, uint32_t attribute)
{
        SpinLock(&page_spin_lock);

        // get pde and pte
        uint32_t *pde = GetPdeVptr(vbase);
        uint32_t *pte = GetPteVptr(vbase);
        uint32_t page_table;

        // if page table tern present
        if (*pde & PAGE_PRESENT)
        {
                // update page
                *pte &= ~PAGEBASE_MASK;
                *pte = pybase | attribute | PAGE_PRESENT;
                // flush buffer
                FlushTBL(vbase);

                SpinUnlock(&page_spin_lock);
                return 0;
        }
        else
        {
                //  if pde no present,do follow
                //  try to alloc a page table
                page_table = (uint32_t)AllocKernelPage(1);

                memset(PAGE_HIGH_BASE + page_table, 0, PAGE_SIZE);

                //  if has alloc successful
                if (page_table)
                {
                        // add pde
                        *pde = page_table | PAGE_PRESENT | PAGE_USERPAGE | PAGE_WRITE; // set page dir to global open

                        // add pte
                        *pte &= ~PAGEBASE_MASK;
                        *pte = pybase | attribute | PAGE_PRESENT;
                        //  flush tbl
                        FlushTBL(vbase);

                        SpinUnlock(&page_spin_lock);
                        return 0;
                }
        }

        SpinUnlock(&page_spin_lock);
        return -1;
}

// disable remap pypage
int _VbaseLinkPybaseSafe(uint32_t vbase, uint32_t pybase, uint32_t attribute)
{
        SpinLock(&page_spin_lock);

        // get pde and pte
        uint32_t *pde = GetPdeVptr(vbase);
        uint32_t *pte = GetPteVptr(vbase);
        uint32_t page_table;

        // if page table tern present
        if (*pde & PAGE_PRESENT)
        {
                // update page , only when page no present
                if (*pte & PAGE_PRESENT)
                {
                        SpinUnlock(&page_spin_lock);
                        return -1;
                }

                *pte &= ~PAGEBASE_MASK;
                *pte = pybase | attribute | PAGE_PRESENT;
                // flush buffer
                FlushTBL(vbase);

                SpinUnlock(&page_spin_lock);
                return 0;
        }
        else
        {
                //  if pde no present,do follow
                //  try to alloc a page table
                page_table = (uint32_t)AllocKernelPage(1);
                //  if has alloc successful
                if (!page_table)
                {
                        SpinUnlock(&page_spin_lock);
                        return -1;
                }

                memset(PAGE_HIGH_BASE + page_table, 0, PAGE_SIZE);

                // add pde
                *pde = page_table | PAGE_PRESENT | PAGE_USERPAGE | PAGE_WRITE; // set page dir to global open

                // add pte
                if (*pte & PAGE_PRESENT)
                {
                        SpinUnlock(&page_spin_lock);
                        return -1;
                }

                *pte &= ~PAGEBASE_MASK;
                *pte = pybase | attribute | PAGE_PRESENT;
                //  flush tbl
                FlushTBL(vbase);

                SpinUnlock(&page_spin_lock);
                return 0;
        }

        SpinUnlock(&page_spin_lock);
        return -1;
}

void _VbaseUnlinkPybase(uint32_t vbase)
{
        uint32_t *pde = GetPdeVptr(vbase);
        uint32_t *pte = GetPteVptr(vbase);

        SpinLock(&page_spin_lock);
        // pde if present
        if (*pde & PAGE_PRESENT)
        {
                // pte if present
                if (*pte & PAGE_PRESENT)
                {
                        // clear pte present flags
                        *pte = *pte & (~PAGE_PRESENT);
                        // flush tbl
                        FlushTBL(vbase);

                }
        }
        SpinUnlock(&page_spin_lock);
}

// map a vbase to py page
int _MapPage(uint32_t startbase, uint32_t len, uint32_t protect)
{
        uint32_t base = startbase&PAGEBASE_MASK; // align to page size
        len=PageAlign(len);
        uint32_t end = startbase + len;
        uint32_t pybase = 0; // alloc page
        uint32_t attr = PAGE_PRESENT;

        #ifdef DEBUG_PAGE_MAP
        KPrint("%s: startbase %x len 0x%x protect %x\n",__func__,startbase,len,protect);
        #endif 

        // attr set
        if (protect & PROTE_KERNEL)
                attr |= KERNEL_PAGE_ATTR;

        if (protect & PROTE_USER)
                attr |= PAGE_USERPAGE;

        if (protect & PROTE_WRITE)
                attr |= PAGE_WRITE;

        // while pybase no above alloc size
        while (base < end)
        {
                pybase = AllocUserPage(1);
                if (!pybase)
                {
                        KPrint("alloc user page failed\n");
                        return -1;
                }
                VbaseLinkPybase(base, pybase, attr);
                // next page base
                base += PAGE_SIZE; // pybase next page
        }
        return 0;
}

// del vbase map
void _UnMapPage(uint32_t startbase, uint32_t 
len, uint32_t fix)
{
        uint32_t base = startbase&PAGEBASE_MASK;
        len=PageAlign(len);
        uint32_t count = len / PAGE_SIZE;
        uint32_t page = 0;

        // check limit
        while (base < (startbase + count * PAGE_SIZE))
        {
                page = (*((pte_t *)GetPteVptr(base))) & PAGEBASE_MASK;

                if (!fix) //no fixed map need to free page on same time
                {
                        // free page
                        FreePage(page);
                }
                // unlink pybase
                _VbaseUnlinkPybase(base);
                // next page
                base += PAGE_SIZE;
        }
}


// map to fixed page
// used to share mem area map
void _MapFixedPage(uint32_t vbase, uint32_t pybase, uint32_t len, uint32_t protect)
{
        uint32_t base_v = vbase&PAGEBASE_MASK;
        uint32_t base_py = pybase;
        uint32_t count = (len + PAGE_SIZE - 1) / PAGE_SIZE;
        uint32_t attr = 0;

        #ifdef DEBUG_PAGE_MAP
        KPrint("map fixed page vbase %x pybase %x len %x\n",vbase,pybase,len);
        #endif 

        // according protect filed set page attr
        if (protect & PROTE_KERNEL)
                attr |= KERNEL_PAGE_ATTR;

        if (protect & PROTE_USER)
                attr |= PAGE_USERPAGE;


        if (protect & PROTE_WRITE)
                attr |= PAGE_WRITE;

        while (base_v < base_v + count * PAGE_SIZE)
        {
                // whether enable remap
                if (protect & PROTE_REMAP)
                {
                        _VbaseLinkPybaseUnsafe(base_v, base_py, attr);
                }
                else
                {
                        _VbaseLinkPybaseSafe(base_v, base_py, attr);
                }
                base_v += PAGE_SIZE;
                base_py += PAGE_SIZE;
        }
}

// map pyaddr on zoom[start,end] map to hight address
void KernelPageMapEarly(uint32_t start, uint32_t end)
{
        // KPrint("map early range: %x-%x\n", start, end);
        uint32_t *pdt = (uint32_t *)PAGEDIR_VBASE;
        uint32_t pde_num = (end - start) / (PTE_MAX * PAGE_SIZE);
        uint32_t pte_num = (end - start) / PAGE_SIZE % PTE_MAX;

        uint32_t *pte_base = (uint32_t *)(PAGETABLE_PYBASE + PAGE_SIZE * PAGETABLE_HAD_MAP);
        uint32_t pde_off = 512 + PAGETABLE_HAD_MAP;

        int i, j;
        for (i = 0; i < pde_num; i++)
        {
                pdt[pde_off + i] = (uint32_t)pte_base | USER_PAGE_ATTR;
                for (j = 0; j < PTE_MAX; j++)
                {
                        pte_base[j] = start | USER_PAGE_ATTR;
                        start += PAGE_SIZE;
                }
                pte_base += PAGE_SIZE;
        }

        if (pte_num > 0)
        {
                pdt[pde_off + i] = (uint32_t)pte_base | USER_PAGE_ATTR;
                for (j = 0; j < pte_num; j++)
                {
                        pte_base[j] = start | USER_PAGE_ATTR;
                        start += PAGE_SIZE;
                }
        }
}

// ready to enable page
void SetupPage()
{
        uint32_t *pdir = (uint32_t *)PAGEDIR_PYBASE;
        uint32_t *ptbl = (uint32_t *)PAGETABLE_PYBASE;
        uint32_t pyaddr = 0;

        // signature page enable
        page_enable = 1;
        KPrint("[page] enable page feature\n");
        // clear page dir
        memset(pdir, 0, PAGE_SIZE);
        int i;
        // fill page table,20MB memory
        for (i = 0; i < 1024 * 5; i++)
        {
                ptbl[i] = pyaddr | USER_PAGE_ATTR;
                pyaddr += PAGE_SIZE;
        }

        // fill page director
        for (i = 0; i < 5; i++)
        {
                pdir[i] = (uint32_t)ptbl | USER_PAGE_ATTR;
                pdir[i + 512] = (uint32_t)ptbl | USER_PAGE_ATTR;
                ptbl += 1024;
        }
        pdir[1023] = (uint32_t)pdir | USER_PAGE_ATTR; // point itself
        // set pdir
        WriteCR3((uint32_t)pdir);
        // enable page
        WriteCR0(ReadCR0() | ENABLE_PAGE);
}

uint64_t KernelPageDirCopyTo()
{
        uint64_t page = AllocKernelPage(1);
        uint32_t *vaddr = KERNEL_PYBASE2VBASE(page);
        uint32_t *kernel_page_dir = PAGEDIR_VBASE;

        // clear new page dir
        memset(vaddr, 0, PAGE_SIZE);

        memcpy((void *)vaddr + PAGE_SIZE / 2, (void *)kernel_page_dir + PAGE_SIZE / 2, PAGE_SIZE / 2);

        // 0-4 MB kernel page
        vaddr[0] = kernel_page_dir[0];

        // last page dir entries point itself
        vaddr[1023] = page | USER_PAGE_ATTR;

        return page;
}